Improve AMP stability #60
Conversation
```python
# --- 2. Sharded Dice Loss ---
mask_pred_probs = F.softmax(local_preds.float(), dim=1)
mask_true_onehot = (
    F.one_hot(local_labels, n_categories + 1)
    .permute(0, 4, 1, 2, 3)
    .float()
)

# Dice loss uses probabilities
dice_score_probs = compute_sharded_dice(
    mask_pred_probs, mask_true_onehot, spatial_mesh
)
```
It looks like this got inserted in the middle of the CE loss calc. Can you move it back to after `CE_loss = ...`? That should also shrink the diff and make it clearer what the actual change is here (now casting `local_preds` to float).
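For context, `compute_sharded_dice` itself isn't part of this diff. A minimal single-device sketch of the soft Dice score it presumably computes per class before reducing across `spatial_mesh` (a hypothetical reconstruction, not the PR's implementation):

```python
import torch

def soft_dice(probs: torch.Tensor, onehot: torch.Tensor, eps: float = 1e-6) -> torch.Tensor:
    # probs, onehot: (N, C, D, H, W); reduce over batch and spatial dims,
    # keeping one Dice score per class
    dims = (0, 2, 3, 4)
    intersection = (probs * onehot).sum(dims)
    cardinality = probs.sum(dims) + onehot.sum(dims)
    return (2.0 * intersection + eps) / (cardinality + eps)
```

A sharded variant would additionally all-reduce `intersection` and `cardinality` across the mesh before dividing, so each rank sees the global score.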
```python
# Set up gradient scaler for AMP (Automatic Mixed Precision)
self.grad_scaler = torch.amp.GradScaler("cuda", enabled=self.config.torch_amp)
self.use_grad_scaler = self.config.torch_amp and self.amp_dtype == torch.float16
```
Is the `and self.amp_dtype == torch.float16` basically just catching the case where we're NOT running with bf16? Would it be better to write that explicitly, like `and self.amp_dtype != torch.bfloat16`?
Yes, basically the options are bf16 or fp16. I don't think bf8/fp8 would work. I'm OK with making this change.
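For context on why the scaler is fp16-only: bf16 has the same exponent range as fp32, so gradients don't underflow and loss scaling is a no-op there. A minimal sketch of the standard fp16 AMP step using this scaler (`model`, `optimizer`, `loader`, and `loss_fn` are placeholders, not names from this PR):

```python
import torch

scaler = torch.amp.GradScaler("cuda", enabled=True)

for batch, labels in loader:
    optimizer.zero_grad(set_to_none=True)
    with torch.autocast(device_type="cuda", dtype=torch.float16):
        loss = loss_fn(model(batch), labels)
    scaler.scale(loss).backward()  # scale the loss so fp16 grads don't underflow
    scaler.step(optimizer)         # unscales grads; skips the step if any are inf/nan
    scaler.update()                # adapts the scale factor for the next iteration
```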
```python
# 2. Sharded Dice Loss
local_preds_softmax = F.softmax(local_preds.float(), dim=1)
local_labels_one_hot = (
    F.one_hot(
        local_labels, num_classes=self.config.n_categories + 1
    )
    .permute(0, 4, 1, 2, 3)
    .float()
)
dice_scores = compute_sharded_dice(
    local_preds_softmax, local_labels_one_hot, self.spatial_mesh
)
```
Same as in `evaluate.py`, I think this should stay after `global_ce_sum = ...`
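For readers following along: one common way per-class Dice scores feed the training objective, sketched under the assumption (not confirmed by this diff) that they're combined additively with the CE term:

```python
# Soft Dice loss = 1 - mean per-class Dice score; `ce_loss` and the unit
# weighting are assumptions for illustration, not the PR's actual code
dice_loss = 1.0 - dice_scores.mean()
total_loss = ce_loss + dice_loss
```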
```python
# Sharded Dice Loss
local_preds_softmax = F.softmax(
    local_preds.float(), dim=1
)
local_labels_one_hot = (
    F.one_hot(
        local_labels,
        num_classes=self.config.n_categories + 1,
    )
    .permute(0, 4, 1, 2, 3)
    .float()
)
# Compute sharded dice
dice_scores = compute_sharded_dice(
    local_preds_softmax,
    local_labels_one_hot,
    self.spatial_mesh,
)
```
Same as `evaluate` and `warmup`, this should come after `global_ce_sum = ...`
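Worth noting why all three call sites cast to `.float()` before the softmax/Dice math: fp16 tops out at 65504, so a Dice denominator summed over a 3D volume (e.g. 64^3 = 262,144 voxels) overflows to inf and then turns the score into nan. A quick illustration:

```python
import torch

# fp16 overflows above 65504; a probability sum over a 64^3 volume can't fit
probs = torch.full((64, 64, 64), 1.0, dtype=torch.float16)
print(probs.sum())          # tensor(inf, dtype=torch.float16)
print(probs.float().sum())  # tensor(262144.)
```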